# mechanisms.py
"""
Defines the planner mechanism functions.
Each function encapsulates the logic for a specific planner type.
"""

import math
import numpy as np
import scipy.optimize
import config # Import configuration
from utils import project_non_negative # Assuming project_non_negative is in utils.py

def planner_algorithm1_dual_sgd(reports_u, public_consumptions_b, current_planner_state, eta):
    """
    Run one round of the dual subgradient-descent planner (Algorithm 1 of Balseiro et al.).

    Every agent is scored by its report minus its dual-priced consumption. The
    best-scoring agent is the tentative winner; it is actually served only when
    its consumption fits in the remaining budget. The dual variable then takes
    one projected subgradient step that always uses the tentative winner's
    consumption, whether or not the allocation went through.

    Args:
        reports_u (list): Reports of the K agents for this round.
        public_consumptions_b (list): Consumption vector of each of the K agents.
        current_planner_state (dict): {'mu': current dual variable, 'B': remaining resources}.
        eta (float): Step-size of the dual update.

    Returns:
        tuple: (allocated_agent_index, payments_vector, consumption_realized,
                next_mu, next_B). allocated_agent_index is -1 when the tentative
                winner does not fit in the budget; payments_vector is all zeros.
    """
    dual_price = current_planner_state['mu']
    budget_left = current_planner_state['B']
    agent_count = len(reports_u)

    # Dual-adjusted value of serving each agent: report u_i plays the role of
    # f_t(i) and public_consumptions_b[i] the role of b_t(i).
    # NOTE(review): only the K agents are candidates here; there is no explicit
    # "allocate to nobody" entry, so the tentative winner is always an agent.
    adjusted_values = [
        reports_u[k] - np.dot(dual_price, public_consumptions_b[k])
        for k in range(agent_count)
    ]
    best_agent = np.argmax(adjusted_values)
    tilde_b = public_consumptions_b[best_agent]  # consumption that drives the dual step

    # Serve the tentative winner only if every resource can cover its consumption.
    if np.all(tilde_b <= budget_left):
        allocated_agent_index = best_agent
        consumption_realized = tilde_b
    else:
        allocated_agent_index = -1
        consumption_realized = np.zeros(config.COST_DIM)

    # This mechanism never charges the agents.
    payments_vector = np.zeros(agent_count)

    # Projected subgradient step on the dual: g_t = rho - b_t(tentative winner).
    rho_vector = config.RHO * np.ones(config.COST_DIM)
    dual_step = eta * (-tilde_b + rho_vector)
    next_mu = project_non_negative(dual_price - dual_step)

    # Remaining resources after this round's realised consumption.
    next_B = budget_left - consumption_realized

    return allocated_agent_index, payments_vector, consumption_realized, next_mu, next_B


def planner_second_price_auction(reports_u, public_consumptions_b, current_planner_B):
    """
    Implements the planner logic for a second-price auction based on reports.

    The agent with the highest report wins, provided that report is strictly
    positive and the agent's consumption fits in the remaining resources. The
    winner pays the highest report among the other agents, floored at zero.
    If the top agent does not fit, nobody is served in this round.

    Args:
        reports_u (list or np.ndarray): Reports from K agents.
        public_consumptions_b (list): List of consumption vectors for K agents.
        current_planner_B (np.array): Current remaining resources.

    Returns:
        tuple: (allocated_agent_index, payments_vector, consumption_realized, next_B)
               allocated_agent_index is -1 when nobody is served.
    """
    num_agents = len(reports_u)
    allocated_agent_index = -1
    consumption_realized = np.zeros(config.COST_DIM)
    payments_vector = np.zeros(num_agents)

    # Use the length rather than the truthiness of reports_u: `not reports_u`
    # raises "truth value of an array is ambiguous" for a NumPy array holding
    # more than one report.
    if num_agents == 0 or not any(r > 0 for r in reports_u):
        # Nothing worth allocating: resources are handed back unchanged.
        next_B = current_planner_B
        return allocated_agent_index, payments_vector, consumption_realized, next_B

    # Highest report wins; max() keeps the first maximal element, so ties go to
    # the lowest agent index.
    winner = max(range(num_agents), key=lambda i: reports_u[i])
    winner_cost = public_consumptions_b[winner]

    # Check resource feasibility for the highest bidder.
    if reports_u[winner] > 0 and np.all(winner_cost <= current_planner_B):
        allocated_agent_index = winner
        consumption_realized = winner_cost

        # Payment is the best competing report, or 0 when there is no
        # competitor or the best competitor's report is not positive.
        runner_up_report = max(
            (reports_u[i] for i in range(num_agents) if i != winner),
            default=0.0,
        )
        payments_vector[winner] = runner_up_report if runner_up_report > 0 else 0.0

    next_B = current_planner_B - consumption_realized

    # This planner does not use or update mu, so it is not returned.
    return allocated_agent_index, payments_vector, consumption_realized, next_B

def planner_optimistic_ftrl(reports_u, public_consumptions_b, current_planner_state, eta):
    """
    Implements the planner logic for Algorithm 2 with Eq. (6) (Epoched O-FTRL Mechanism).

    Rounds are grouped into geometrically growing epochs (base 1.5). On the
    first round of an epoch the dual variable is recomputed from the recorded
    history with bounded Nelder-Mead searches. In every round the planner then
    runs either a random "isolation" round or a dual-adjusted second-price
    allocation, and serves the chosen agent only if its consumption fits in the
    remaining resources. At least one agent is required.

    Args:
        reports_u (list): List of reports from K agents.
        public_consumptions_b (list): List of consumption vectors for K agents.
        current_planner_state (dict): {'t': current round, 'mu': current_mu, 'B': current_B,
                                       'history_i': allocated agent index per past round (-1 for none),
                                       'history_u': reports per past round, indexed [round][agent],
                                       'history_b': consumption vectors per past round, indexed [round][agent]}.
        eta (float): Base step-size for the dual update; it is rescaled per epoch.

    Returns:
        tuple: (allocated_agent_index, payments_vector, consumption_realized,
                mu, next_B, mu_diff_stats)
               mu is the dual variable used in this round (recomputed on the
               first round of an epoch, otherwise the incoming one).
               mu_diff_stats is None except on the first round of an epoch,
               where it is a pair of relative distances, each capped at 1:
               (loss minimiser vs. closed-form clipped FTRL dual,
                loss minimiser vs. fixed-point dual).
    """
    t = current_planner_state['t']
    mu_t_planner = current_planner_state['mu']
    B_t_planner = current_planner_state['B']
    num_agents = len(reports_u)

    def calculate_epoch(round_index):
        """Return ell with 1.5**ell <= round_index + 1 < 1.5**(ell + 1), or -1 before round 0."""
        if round_index < 0:
            return -1
        epoch = 0
        while 1.5 ** (epoch + 1) <= round_index + 1:
            epoch += 1
        return epoch

    ell = calculate_epoch(t)

    # Nominal length of the epoch covering rounds with
    # 1.5**ell <= t + 1 < 1.5**(ell + 1). It is computed once so that the dual
    # recomputation and the isolation probability agree in every round of the
    # epoch, not only in its first round.
    current_epoch_length = 1.5 ** (ell + 1) - 1.5 ** ell

    mu_diff_stats = None
    if calculate_epoch(t - 1) != ell:
        # First round of a new epoch: recompute the dual from the history.
        history_u = current_planner_state['history_u']
        history_b = current_planner_state['history_b']
        history_i = current_planner_state['history_i']
        rho_vector = config.RHO * np.ones(config.COST_DIM)

        # Sum of the realised loss gradients rho - b_tau(winner_tau) over all
        # past rounds (the consumption term vanishes when nobody was served).
        g_history = np.zeros(config.COST_DIM)
        for tau in range(t):
            winner_round_tau = history_i[tau]
            if winner_round_tau != -1:
                g_history += -history_b[tau][winner_round_tau] + rho_vector
            else:
                g_history += rho_vector

        def g_hat_estimated(mu):
            """Predict the coming epoch's gradient by replaying past rounds under dual `mu`."""
            g_hat = np.zeros(config.COST_DIM)
            for tau in range(t):
                adjusted_values = [
                    history_u[tau][i] - np.dot(mu, history_b[tau][i])
                    for i in range(num_agents)
                ]
                winner_under_mu = np.argmax(adjusted_values)
                g_hat += (-history_b[tau][winner_under_mu] + rho_vector) / t
            # Average per-round gradient scaled to the length of the new epoch.
            return g_hat * current_epoch_length

        # Step-size shrinks with the square root of the horizon reached so far.
        eta_current = eta / (1.5 ** (ell + 2)) ** 0.5

        def objective_loss(mu):
            """Linearised loss with the optimistic prediction plus a Euclidean regulariser."""
            regularizer = 1 / (2 * eta_current) * np.dot(mu, mu)
            return np.inner(g_history + g_hat_estimated(mu), mu) + regularizer

        def objective_fixed_point(mu):
            """Squared residual of mu = -eta * (g_history + g_hat(mu))."""
            return np.linalg.norm(mu + eta_current * (g_history + g_hat_estimated(mu))) ** 2

        # Both searches are derivative-free (g_hat is piecewise constant in mu),
        # start from the incoming dual and stay inside the box [0, 2 / rho].
        bounds = [(0, 2 / rho_vector[i]) for i in range(config.COST_DIM)]
        mu_start = mu_t_planner

        def minimize_over_box(objective):
            return scipy.optimize.minimize(
                objective,
                mu_start,
                method='Nelder-Mead',
                bounds=bounds,
                options={'maxiter': 200},
            )

        # The loss minimiser is only used for the diagnostics below; the dual
        # actually adopted for this epoch is the fixed-point solution.
        opt_result = minimize_over_box(objective_loss)
        fixed_p_result = minimize_over_box(objective_fixed_point)
        mu_t_planner = fixed_p_result.x

        # Closed-form FTRL dual (no optimistic prediction), for comparison only.
        # NOTE(review): this clips at 1 / rho while the searches above use
        # 2 / rho as the upper bound -- confirm which cap is intended.
        mu_t_ftrl = np.minimum(1 / rho_vector, np.maximum(0, -eta_current * g_history))

        def relative_gap(a, b):
            """Distance between a and b relative to the larger norm, capped at 1; 0 if undefined."""
            scale = max(np.linalg.norm(a), np.linalg.norm(b))
            if scale == 0:
                return 0
            gap = np.linalg.norm(a - b) / scale
            return min(1, gap) if np.isfinite(gap) else 0

        mu_diff_stats = (relative_gap(opt_result.x, mu_t_ftrl),
                         relative_gap(opt_result.x, fixed_p_result.x))

    tilde_x_agent_index = -1
    tilde_p = 0
    tilde_b = np.zeros(config.COST_DIM)  # Consumption of the tentative winner

    if np.random.uniform(0, 1) <= 1 / current_epoch_length:
        # Isolation round happens (w.p. 1 / |E_ell| independently)
        winner_candidate = np.random.randint(0, num_agents)  # Pick a candidate uniformly
        payment_candidate = np.random.uniform(0, 1)  # Pick a payment uniformly

        if reports_u[winner_candidate] >= payment_candidate:
            # Allocate to the random i at price p if u_i >= p
            tilde_x_agent_index = winner_candidate
            tilde_b = public_consumptions_b[tilde_x_agent_index]
            tilde_p = payment_candidate
    else:
        # A standard allocation round (sub-routine Algorithm 3): rank agents by
        # dual-adjusted value u_i - mu^T b_i. The sort is stable, so ties go to
        # the lowest agent index.
        ranking = sorted(
            range(num_agents),
            key=lambda i: reports_u[i] - np.dot(mu_t_planner, public_consumptions_b[i]),
            reverse=True,
        )

        tilde_x_agent_index = ranking[0]
        tilde_b = public_consumptions_b[tilde_x_agent_index]
        if len(ranking) > 1:
            # Second price in dual-adjusted terms: the winner's dual-priced
            # cost plus the runner-up's adjusted value.
            runner_up = ranking[1]
            tilde_p = np.dot(mu_t_planner, tilde_b) + \
                      reports_u[runner_up] - \
                      np.dot(mu_t_planner, public_consumptions_b[runner_up])
        else:
            # A single agent has no competitor; charge only the dual-priced cost.
            tilde_p = np.dot(mu_t_planner, tilde_b)

    # Actual allocation based on feasibility
    allocated_agent_index = -1
    payments_vector = np.zeros(num_agents)
    consumption_realized = np.zeros(config.COST_DIM)
    if tilde_x_agent_index != -1 and np.all(tilde_b <= B_t_planner):
        allocated_agent_index = tilde_x_agent_index
        consumption_realized = tilde_b
        payments_vector[allocated_agent_index] = tilde_p

    # Update resources
    next_B = B_t_planner - consumption_realized

    return allocated_agent_index, payments_vector, consumption_realized, mu_t_planner, next_B, mu_diff_stats

def planner_ftrl(reports_u, public_consumptions_b, current_planner_state, eta):
    """
    Implements the planner logic for Algorithm 2 with Eq. (5) (Epoched FTRL Mechanism).

    Rounds are grouped into equal epochs of length floor(T^(1/3)). On the first
    round of an epoch the dual variable is reset to the clipped closed-form
    FTRL solution computed from the recorded history. In every round the
    planner then runs either a random "isolation" round or a dual-adjusted
    second-price allocation, and serves the chosen agent only if its
    consumption fits in the remaining resources. At least one agent is required.

    Args:
        reports_u (list): List of reports from K agents.
        public_consumptions_b (list): List of consumption vectors for K agents.
        current_planner_state (dict): {'t': current round, 'mu': current_mu, 'B': current_B,
                                       'history_i': allocated agent index per past round (-1 for none),
                                       'history_u': all previous reports (not read by this planner),
                                       'history_b': consumption vectors per past round, indexed [round][agent]}.
        eta (float): Base step-size for the dual update; it is rescaled per epoch.

    Returns:
        tuple: (allocated_agent_index, payments_vector, consumption_realized,
                mu, next_B)
               mu is the dual variable used in this round (recomputed on the
               first round of an epoch, otherwise the incoming one).
    """
    t = current_planner_state['t']
    mu_t_planner = current_planner_state['mu']
    B_t_planner = current_planner_state['B']
    num_agents = len(reports_u)

    # Epoch length floor(T^(1/3)). The float power alone is not reliable for
    # perfect cubes (1000 ** (1 / 3) evaluates to 9.999...), so the rounded
    # estimate is corrected with exact comparisons. The floor of 1 avoids a
    # zero-length epoch for tiny horizons.
    epoch_length = int(round(config.T ** (1 / 3)))
    while epoch_length ** 3 > config.T:
        epoch_length -= 1
    while (epoch_length + 1) ** 3 <= config.T:
        epoch_length += 1
    epoch_length = max(1, epoch_length)

    def calculate_epoch(round_index):
        """Return the 1-based epoch ell with epoch_length * (ell - 1) < round_index + 1 <= epoch_length * ell."""
        if round_index < 0:
            return -1
        return math.ceil((round_index + 1) / epoch_length)

    ell = calculate_epoch(t)

    if calculate_epoch(t - 1) != ell:
        # First round of a new epoch: recompute the dual from the history.
        history_b = current_planner_state['history_b']
        history_i = current_planner_state['history_i']
        rho_vector = config.RHO * np.ones(config.COST_DIM)

        # Sum of the realised loss gradients rho - b_tau(winner_tau) over all
        # past rounds (the consumption term vanishes when nobody was served).
        g_history = np.zeros(config.COST_DIM)
        for tau in range(t):
            winner_round_tau = history_i[tau]
            if winner_round_tau != -1:
                g_history += -history_b[tau][winner_round_tau] + rho_vector
            else:
                g_history += rho_vector

        # Closed-form FTRL dual with a Euclidean regulariser, clipped to the
        # box [0, 2 / rho]; the step-size shrinks as epochs accumulate.
        eta_current = eta / (ell * epoch_length * epoch_length) ** 0.5
        mu_t_planner = np.minimum(2 / rho_vector, np.maximum(0, -eta_current * g_history))

    tilde_x_agent_index = -1
    tilde_p = 0
    tilde_b = np.zeros(config.COST_DIM)  # Consumption of the tentative winner

    if np.random.uniform(0, 1) <= 1 / epoch_length:
        # Isolation round happens (w.p. 1 / |E_ell| independently)
        winner_candidate = np.random.randint(0, num_agents)  # Pick a candidate uniformly
        payment_candidate = np.random.uniform(0, 1)  # Pick a payment uniformly

        if reports_u[winner_candidate] >= payment_candidate:
            # Allocate to the random i at price p if u_i >= p
            tilde_x_agent_index = winner_candidate
            tilde_b = public_consumptions_b[tilde_x_agent_index]
            tilde_p = payment_candidate
    else:
        # A standard allocation round (sub-routine Algorithm 3): rank agents by
        # dual-adjusted value u_i - mu^T b_i. The sort is stable, so ties go to
        # the lowest agent index.
        ranking = sorted(
            range(num_agents),
            key=lambda i: reports_u[i] - np.dot(mu_t_planner, public_consumptions_b[i]),
            reverse=True,
        )

        tilde_x_agent_index = ranking[0]
        tilde_b = public_consumptions_b[tilde_x_agent_index]
        if len(ranking) > 1:
            # Second price in dual-adjusted terms: the winner's dual-priced
            # cost plus the runner-up's adjusted value.
            runner_up = ranking[1]
            tilde_p = np.dot(mu_t_planner, tilde_b) + \
                      reports_u[runner_up] - \
                      np.dot(mu_t_planner, public_consumptions_b[runner_up])
        else:
            # A single agent has no competitor; charge only the dual-priced cost.
            tilde_p = np.dot(mu_t_planner, tilde_b)

    # Actual allocation based on feasibility
    allocated_agent_index = -1
    payments_vector = np.zeros(num_agents)
    consumption_realized = np.zeros(config.COST_DIM)
    if tilde_x_agent_index != -1 and np.all(tilde_b <= B_t_planner):
        allocated_agent_index = tilde_x_agent_index
        consumption_realized = tilde_b
        payments_vector[allocated_agent_index] = tilde_p

    # Update resources
    next_B = B_t_planner - consumption_realized

    return allocated_agent_index, payments_vector, consumption_realized, mu_t_planner, next_B
